import torch
from torch import nn
import torch.nn.functional as F

from tools.utils import kl_normal_log

from .PriorNet import PriorNet
from .PosteriorNet import PosteriorNet
from .GateMechanism import GTU
from .AddNoise import add_gaussian_noise, add_dropout_noise
from .AttentionLayer import Attention

class CausalHMM(nn.Module):
    """Couple a ``PriorNet`` (seeker side) with a ``PosteriorNet`` (supporter side).

    At each time step ``forward`` feeds the seeker tensors, together with a
    projected emotion embedding, to ``PriorNet``. It feeds the supporter
    tensors, together with a projected strategy embedding, to
    ``PosteriorNet``. It then combines the per-step outputs and losses.

    NOTE(review): the core implementation was removed before release, so
    ``dynamic_matching`` is a stub. ``forward`` also references names it never
    defines (``max_time``, ``batch``, ``active_seekers``, ...). Neither method
    is runnable as shown.
    """

    def __init__(self,
                 utter_dim,
                 latent_dim=48,
                 prior_type='GRU',
                 posterior_type='FC',
                 dropout_prob=0.1,
                 mu_type='share',
                 var_type='share',
                 use_reparameterize=True,
                 activation_type='relu',
                 pred_z_init='origin',
                 ):
        """Build the prior/posterior sub-networks and the per-step projections.

        Args:
            utter_dim: Feature size of the emotion/strategy embeddings, i.e.
                the in/out size of the two projection layers. Also forwarded
                to both sub-networks.
            latent_dim: Forwarded to ``PriorNet`` and ``PosteriorNet``.
            prior_type: Architecture selector forwarded to ``PriorNet``.
            posterior_type: Architecture selector forwarded to ``PosteriorNet``.
            dropout_prob: Forwarded to both sub-networks; also stored.
            mu_type: Forwarded to both sub-networks (semantics defined there).
            var_type: Forwarded to both sub-networks; also stored.
            use_reparameterize: Forwarded to ``PosteriorNet`` only.
            activation_type: Stored as ``self.activation`` and forwarded to
                both sub-networks.
            pred_z_init: Stored on the instance; not used in the visible code.
        """
        super(CausalHMM, self).__init__()

        self.utter_dim = utter_dim
        self.latent_dim = latent_dim
        self.var_type = var_type
        self.pred_z_init = pred_z_init
        self.activation = activation_type
        self.dropout_prob = dropout_prob

        self.PriorNet = PriorNet(
            utter_dim=self.utter_dim,
            latent_dim=self.latent_dim,
            prior_type=prior_type,
            var_type=var_type,
            mu_type=mu_type,
            activation=self.activation,
            dropout_prob=dropout_prob,
        )

        self.PosteriorNet = PosteriorNet(
            utter_dim=self.utter_dim,
            latent_dim=self.latent_dim,
            posterior_type=posterior_type,
            mu_type=mu_type,
            var_type=var_type,
            use_reparameterize=use_reparameterize,
            activation=self.activation,
            dropout_prob=dropout_prob,
        )

        # Per-step projections applied before calling the sub-networks.
        self.prior_emotion_FC = nn.Linear(self.utter_dim, self.utter_dim)
        self.posterior_strategy_FC = nn.Linear(self.utter_dim, self.utter_dim)

        # Learnable scalars initialised to zero; not used in the visible code.
        self.noise_weight_seeker = nn.Parameter(torch.zeros(1))
        self.noise_weight_supporter = nn.Parameter(torch.zeros(1))

        # Gating units; not used in the visible code.
        self.pr_gate = GTU()
        self.po_gate = GTU()

    # The core implementation of the code has been removed, and the full code
    # will be released upon the paper's acceptance.
    def dynamic_matching(self, prior_eps_mu_list, prior_eps_logvar_list,
                         prior_z_mu_list, prior_z_logvar_list,
                         posterior_mu_list, posterior_logvar_list, max_time):
        """Match per-step prior statistics against posterior statistics.

        ``forward`` stores the result as ``po_eps_z_all_kl_loss`` and calls
        ``.double()`` on it, so it is presumably a scalar tensor loss.

        Raises:
            NotImplementedError: always. The implementation has been removed
                from this release; previously this failed with a ``NameError``
                on the undefined return variable.
        """
        raise NotImplementedError(
            "CausalHMM.dynamic_matching: core implementation not yet released")

    def forward(self, strategy_embs, emotion_embs, seeker_tensors,
                supporter_tensors, seeker_mask, supporter_mask, test=False):
        """Run the prior and posterior networks step by step and combine them.

        NOTE(review): the setup code that defines the following names was
        removed with the core implementation:

        - ``max_time`` and ``batch``;
        - ``active_seekers`` / ``active_supporters``;
        - the hidden states ``pr_hidden_last`` / ``po_hidden_last``;
        - the loss accumulators and the ``*_list`` buffers;
        - the ``last_active_time_*``, ``updated_*`` and ``final_*`` tensors.

        As shown, this method therefore raises ``NameError``.

        Args:
            strategy_embs: 3-D, sliced as ``[:, t, :]`` per time step ``t``;
                projected by ``posterior_strategy_FC``.
            emotion_embs: 3-D, sliced as ``[:, t, :]``; projected by
                ``prior_emotion_FC``.
            seeker_tensors: 3-D, sliced as ``[:, t, :]``; passed to ``PriorNet``.
            supporter_tensors: 3-D, sliced as ``[:, t, :]``; passed to
                ``PosteriorNet``.
            seeker_mask, supporter_mask: Not referenced in the visible code.
                Presumably they are used to derive the per-step active masks.
            test: Forwarded to ``PosteriorNet``.

        Returns:
            Tuple ``(final_eps_mu, kl_loss, rec_loss, causal_loss)``:

            - ``final_eps_mu``: the seeker and supporter representations taken
              from each sample's last active step, concatenated along dim 1;
            - ``kl_loss``: the ``dynamic_matching`` loss cast to float64;
            - ``rec_loss`` and ``causal_loss``: summed over steps and divided
              by ``max_time``.
        """
        for t in range(max_time):
            # Indexing as seeker_tensors[active_seekers, t, :] would change the
            # batch dimension from step to step, because a different number of
            # samples is active at each step. Presumably that is why the full
            # slice is passed together with a mask instead.
            if active_seekers.any():
                prior_emotion_features = self.prior_emotion_FC(emotion_embs[:, t, :])
                (pr_eps_mu_cur, pr_z_mu, pr_z_logvar,
                 pr_eps_mu, pr_eps_logvar, causal_loss) = self.PriorNet(
                    prior_emotion_features, seeker_tensors[:, t, :],
                    pr_hidden_last, po_hidden_last, mask=active_seekers)

            if active_supporters.any():
                posterior_strategy_features = self.posterior_strategy_FC(strategy_embs[:, t, :])
                po_eps_mu, po_eps_logvar, po_z_cur, rec_loss = self.PosteriorNet(
                    posterior_strategy_features, supporter_tensors[:, t, :],
                    po_hidden_last, mask=active_supporters, test=test)

                po_rec_all_loss += rec_loss
                # NOTE(review): causal_loss is produced in the seeker branch
                # but accumulated here, in the supporter branch. At a step
                # where supporters are active but seekers are not, this adds a
                # stale value from an earlier step. If no seeker step has run
                # yet, it raises NameError. Confirm the intended placement.
                pr_causal_loss += causal_loss

        # Update each sample's final latent representation from its recorded
        # last active time step. A value of -1 presumably marks a sample that
        # was never active; those entries are not overwritten.
        for idx in range(batch):
            if last_active_time_seeker[idx] != -1:
                # NOTE(review): the seeker side reads updated_pr_z_mu while the
                # supporter side reads updated_po_eps_mu; confirm this
                # asymmetry is intended.
                final_pr_eps_mu[idx] = updated_pr_z_mu[idx, last_active_time_seeker[idx]]
            if last_active_time_supporter[idx] != -1:
                final_po_eps_mu[idx] = updated_po_eps_mu[idx, last_active_time_supporter[idx]]

        final_eps_mu = torch.cat((final_pr_eps_mu, final_po_eps_mu), dim=1)

        po_eps_z_all_kl_loss = self.dynamic_matching(
            prior_eps_mu_list, prior_eps_logvar_list, prior_z_mu_list,
            prior_z_logvar_list, posterior_mu_list, posterior_logvar_list,
            max_time)

        # Average the accumulated losses over the number of time steps.
        pr_causal_loss = pr_causal_loss / max_time
        po_rec_all_loss = po_rec_all_loss / max_time

        # Return the step-averaged causal loss, consistent with the averaged
        # reconstruction loss. The previous code returned the last step's raw
        # ``causal_loss`` and discarded the average computed just above.
        return final_eps_mu, po_eps_z_all_kl_loss.double(), po_rec_all_loss, pr_causal_loss